#!/usr/bin/env python

import argparse
import kalibr_common as kc
import math
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import sm


def parse_arguments():
    """Parse the command line options of the distortion visualizer.

    Returns the argparse namespace with the attributes ``chain_yaml``
    (required, given via ``--cam``), ``number_of_rectangles`` (int,
    default 5) and ``line_thickness`` (int, default 3).
    """
    arg_parser = argparse.ArgumentParser(
        description='Visualize distortion of a camchain file.')
    arg_parser.add_argument('--cam', dest='chain_yaml', required=True,
                            help='Camera configuration as yaml file.')

    # Both drawing options are plain integers that only differ in name,
    # help text and default value.
    integer_options = (
        ('--number_of_rectangles', "Number of rectangles to draw.", 5),
        ('--line_thickness', "Line thickness.", 3),
    )
    for flag, help_text, fallback in integer_options:
        arg_parser.add_argument(flag, help=help_text, type=int,
                                default=fallback)

    return arg_parser.parse_args()

def normalize(point, focal_lengths, principal_point):
    """Convert a pixel position into normalized image coordinates.

    Args:
        point: Pair ``(u, v)`` of pixel coordinates. Note the pairing used
            below: the first component is combined with index 1 of
            ``focal_lengths``/``principal_point`` and the second component
            with index 0. ``distort`` in this file calls this function with
            ``(row, column)`` array indices.
        focal_lengths: Indexable holding two focal lengths.
        principal_point: Indexable holding the two principal point
            coordinates.

    Returns:
        Tuple ``(x, y)`` with the principal point subtracted and the result
        divided by the focal length, component by component.
    """
    # Tuple parameters in the signature were removed in Python 3, so the
    # pair is unpacked here instead; callers still pass a single pair.
    u, v = point
    x = (u - principal_point[1]) / focal_lengths[1]
    y = (v - principal_point[0]) / focal_lengths[0]
    return (x, y)

def denormalize(point, focal_lengths, principal_point):
    """Convert normalized image coordinates back into a pixel position.

    This is the inverse of ``normalize`` and uses the same pairing: the
    first component is combined with index 1 of
    ``focal_lengths``/``principal_point`` and the second with index 0.

    Args:
        point: Pair ``(x, y)`` of normalized coordinates.
        focal_lengths: Indexable holding two focal lengths.
        principal_point: Indexable holding the two principal point
            coordinates.

    Returns:
        Tuple ``(u, v)`` obtained by scaling with the focal length and
        adding the principal point, component by component.
    """
    # Tuple parameters in the signature were removed in Python 3, so the
    # pair is unpacked here instead; callers still pass a single pair.
    x, y = point
    u = x * focal_lengths[1] + principal_point[1]
    v = y * focal_lengths[0] + principal_point[0]
    return (u, v)

def distortPoint(point, distortion_model, distortion_coeffs):
    """Distort a point according to the given distortion model.

    Args:
        point: Pair ``(x, y)`` in normalized image coordinates.
        distortion_model: One of "radial-tangential"/"radtan",
            "equidistant"/"equi" or "fov".
        distortion_coeffs: Indexable with the model coefficients:
            ``[k1, k2, p1, p2]`` for radtan, four polynomial coefficients
            for equidistant and ``[w]`` for fov.

    Returns:
        Tuple ``(x, y)`` with the distorted normalized coordinates.

    Raises:
        ValueError: If the distortion model is not one of the names above.
    """
    # Tuple parameters in the signature were removed in Python 3, so the
    # pair is unpacked here instead; callers still pass a single pair.
    x, y = point
    if distortion_model == "radial-tangential" or distortion_model == "radtan":
        k1, k2 = distortion_coeffs[0], distortion_coeffs[1]
        p1, p2 = distortion_coeffs[2], distortion_coeffs[3]
        x2 = x**2
        y2 = y**2
        xy = x * y
        rho2 = x2 + y2
        rad_dist = k1 * rho2 + k2 * rho2**2
        # The tangential terms are mirrored between the two axes: p1 takes
        # the mixed term for x and the squared term for y, p2 the opposite.
        x_dist = x + x * rad_dist + 2.0 * p1 * xy + p2 * (rho2 + 2.0 * x2)
        y_dist = y + y * rad_dist + p1 * (rho2 + 2.0 * y2) + 2.0 * p2 * xy
        return (x_dist, y_dist)
    elif distortion_model == "equidistant" or distortion_model == "equi":
        r = math.sqrt(x**2 + y**2)
        # At the image centre the scaling thetad / r is undefined; the point
        # is returned unchanged there.
        if r < 1e-8:
            return (x, y)
        theta = math.atan(r)
        thetad = theta * (1 + distortion_coeffs[0] * theta**2 +
                          distortion_coeffs[1] * theta**4 +
                          distortion_coeffs[2] * theta**6 +
                          distortion_coeffs[3] * theta**8)
        scaling = thetad / r
        return (scaling * x, scaling * y)
    elif distortion_model == "fov":
        w = distortion_coeffs[0]
        r_u = np.linalg.norm((x, y))
        tanwhalf = math.tan(w / 2)
        atan_wrd = math.atan(2. * tanwhalf * r_u)

        # The general ratio divides by r_u * w, so tiny values of either
        # factor are handled by dedicated branches.
        if w * w < 1e-5:
            r_rd = 1.0
        elif r_u * r_u < 1e-5:
            r_rd = 2. * tanwhalf / w
        else:
            r_rd = atan_wrd / (r_u * w)
        return (x * r_rd, y * r_rd)
    else:
        # ValueError derives from Exception, so callers that catch Exception
        # keep working.
        raise ValueError(
            'Distortion model ' + str(distortion_model) + ' not found.')

def distort(image, distortion_model, distortion_coeffs, focal_lengths, 
        principal_point):
    """Forward-map every pixel of an image through the distortion model.

    Each pixel is normalized, distorted and converted back to a pixel
    position; its value is then written to that position in the output.
    Pixels that land outside the image are dropped, and output pixels that
    no source pixel maps to stay zero.

    Args:
        image: Array indexed as ``image[row, column]``.
        distortion_model: Model name understood by ``distortPoint``.
        distortion_coeffs: Coefficients of that model.
        focal_lengths: Two focal lengths, as expected by ``normalize``.
        principal_point: Two principal point coordinates, as expected by
            ``normalize``.

    Returns:
        Array of zeros with the shape of ``image`` holding the remapped
        pixel values.
    """
    num_rows, num_cols = image.shape[0], image.shape[1]
    distorted_image = np.zeros(image.shape)
    for i in range(num_rows):
        for j in range(num_cols):
            # normalize() pairs its first input with index 1 of the
            # intrinsics, so for (row, column) input its first output is the
            # coordinate along the rows. Presumably that is the vertical (y)
            # axis, assuming the [fu, fv] / [pu, pv] ordering of Kalibr
            # camchain files -- confirm against the camchain format.
            y_norm, x_norm = normalize((i, j), focal_lengths,
                                       principal_point)
            # distortPoint() expects (x, y); the order matters for models
            # that are not radially symmetric.
            x_dist, y_dist = distortPoint((x_norm, y_norm), distortion_model,
                                          distortion_coeffs)
            row_f, col_f = denormalize((y_dist, x_dist), focal_lengths,
                                       principal_point)
            # floor() instead of int(): truncation would fold positions in
            # (-1, 0) onto row/column 0 instead of rejecting them.
            row = int(math.floor(row_f))
            col = int(math.floor(col_f))
            # Valid indices end at size - 1, hence the strict upper bound.
            if not (0 <= row < num_rows and 0 <= col < num_cols):
                continue
            distorted_image[row, col] = image[i, j]
    return distorted_image

def drawImage(resolution, num_rectangles, line_thickness):
    """Draw concentric rectangles around the image centre.

    Args:
        resolution: Indexable ``(width, height)`` in pixels.
        num_rectangles: Number of rectangles to draw, must be positive.
        line_thickness: Thickness of the rectangle outlines in pixels, must
            be positive. Additional thickness grows outwards from the
            rectangle.

    Returns:
        Array of shape ``(height, width)`` that is 0 everywhere except on
        the rectangle outlines, which are set to 255. Outline rows or
        columns that would fall outside the image are skipped.
    """
    assert num_rectangles > 0
    assert line_thickness > 0
    width, height = int(resolution[0]), int(resolution[1])
    center_col = width // 2
    center_row = height // 2
    raw_image = np.zeros((height, width))
    for rect in range(1, num_rectangles + 1):
        # Floor division keeps the offsets integral on Python 3, where "/"
        # would produce floats that can neither feed range() nor index an
        # array.
        h_offset = rect * width // 2 // (num_rectangles + 1)
        v_offset = rect * height // 2 // (num_rectangles + 1)
        col_start = max(center_col - h_offset, 0)
        col_stop = center_col + h_offset
        row_start = max(center_row - v_offset, 0)
        row_stop = center_row + v_offset
        for t in range(line_thickness):
            # Top and bottom edges.
            for row in (center_row - v_offset - t, center_row + v_offset + t):
                # Skip rows outside the image: a negative index would wrap
                # around to the opposite border and a too large one raises.
                if 0 <= row < height:
                    raw_image[row, col_start:col_stop] = 255
            # Left and right edges.
            for col in (center_col - h_offset - t, center_col + h_offset + t):
                if 0 <= col < width:
                    raw_image[row_start:row_stop, col] = 255
    return raw_image

def imshow(image, color, alpha = 1):
    """Show an image on the current matplotlib axes.

    The image is drawn with the colormap ``color``, nearest-neighbour
    interpolation and the opacity ``alpha``. The axes use equal scaling and
    are limited to 0..columns horizontally and 0..rows vertically.
    """
    num_rows, num_cols = image.shape[0], image.shape[1]
    plt.imshow(image, cmap=color, interpolation='nearest', alpha=alpha)
    plt.axis('equal')
    axis_limits = [0, num_cols, 0, num_rows]
    plt.axis(axis_limits)

def view_distortions():
    """Plot the distortion of every camera in the camchain file.

    For each camera a pattern of rectangles is drawn in blue and overlaid
    with its distorted version in semi-transparent red. Every camera gets
    its own subplot row, and the figure is shown once all are drawn.
    """
    args = parse_arguments()
    chain = kc.ConfigReader.CameraChainParameters(args.chain_yaml)
    camera_count = chain.numCameras()
    print("Plotting distortion with " + str(camera_count) + " cameras.")

    for index in range(camera_count):
        cam = 'cam{0}'.format(index)
        cam_config = chain.data[cam]

        # The first two intrinsics are used as focal lengths, the next two
        # as the principal point.
        intrinsics = np.asarray(cam_config['intrinsics'])
        resolution = np.asarray(cam_config['resolution'])
        coeffs = np.asarray(cam_config['distortion_coeffs'])
        model = cam_config['distortion_model']
        focal_lengths = intrinsics[0:2]
        principal_point = intrinsics[2:4]

        plt.subplot(camera_count, 1, index + 1)
        pattern = drawImage(resolution, args.number_of_rectangles,
                            args.line_thickness)
        imshow(pattern, 'Blues')
        warped = distort(pattern, model, coeffs, focal_lengths,
                         principal_point)
        imshow(warped, 'Reds', alpha=0.5)
        plt.title(cam)

    plt.show()

# Run the visualization only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    view_distortions()
